from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
from causally.model.utils import set_color


class AbstractRecommender(nn.Module):
    r"""Base class for all models.

    Stores the configuration and dataset passed to the constructor and
    provides helpers shared by concrete models: reading/restoring extra
    (non-``nn.Parameter``) attributes via :meth:`other_parameter` /
    :meth:`load_other_parameter`, and a ``__str__`` that appends the number
    of trainable parameters. Subclasses must implement :meth:`calculate_loss`.
    """

    def __init__(self, config, dataset):
        """Initialise state shared by all models.

        Args:
            config: Mapping that must provide the keys ``'x_n_covariate'``
                and ``'v_n_covariate'``; both are read eagerly, so a missing
                key raises ``KeyError`` here.
            dataset: Dataset object, stored unchanged on ``self.dataset``.
        """
        # nn.Module.__init__ must run before any attribute assignment: it
        # creates the _parameters/_modules/_buffers registries that
        # nn.Module.__setattr__ consults. Assigning a Module or Parameter
        # before they exist raises AttributeError.
        super().__init__()

        # getLogger() with no name returns the root logger.
        self.logger = getLogger()
        self.config = config
        self.dataset = dataset

        self.x_n_covariate = config['x_n_covariate']
        self.v_n_covariate = config['v_n_covariate']

    def calculate_loss(self, x, t, y, v):
        """Compute the training loss; must be overridden by subclasses.

        NOTE(review): the argument meanings are not defined in this class —
        presumably covariates ``x``, treatment ``t``, outcome ``y`` and
        extra covariates ``v``; confirm against subclasses / the trainer.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def other_parameter(self):
        """Return the extra attributes named by ``other_parameter_name``.

        Returns:
            dict: ``{name: getattr(self, name)}`` for every name in
            ``self.other_parameter_name`` when the instance defines that
            attribute; otherwise an empty dict.
        """
        if hasattr(self, "other_parameter_name"):
            return {key: getattr(self, key) for key in self.other_parameter_name}
        return dict()

    def load_other_parameter(self, para):
        """Set each ``key: value`` pair of *para* as an attribute on ``self``.

        Args:
            para: Mapping of attribute names to values, or ``None`` (no-op).
        """
        if para is None:
            return
        for key, value in para.items():
            setattr(self, key, value)

    def __str__(self):
        """Return the standard module repr plus the trainable-parameter count."""
        return (
            super().__str__()
            + set_color("\nTrainable parameters", "blue")
            + f": {self._count_trainable_parameters()}"
        )

    def _count_trainable_parameters(self):
        """Return the total number of elements in parameters with ``requires_grad``."""
        # Tensor.numel() is an exact int and is 1 for 0-d (scalar) parameters.
        # The previous np.prod(p.size()) returned the float 1.0 for those,
        # which turned the whole sum into a float.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

